# This script embeds the combined news corpus using GloVe.
# It also generates, for the robustness checks, transformation matrices at
# different weightings and GloVe models at different vocabulary sizes.
library(quanteda)
library(text2vec)
library(fst)
library(conText)
library(dplyr)
library(tidylog)
set.seed(123L)

# ---- Corpus sources ---------------------------------------------------------
# Source names; each one is a sub-folder of the combined output directory.
countries <- c(
  "djazairess",
  "maghress",
  "masress",
  "sauress",
  "turess"
)
num_countries <- length(countries)
origin_path <- "data/output/combined/"

# ---- Model parameters -------------------------------------------------------
WINDOW_SIZE <- 6 # co-occurrence window passed to fcm()
DIM <- 300 # GloVe rank (embedding dimensionality)
ITERS <- 200 # n_iter passed to GlobalVectors$fit_transform()

# ---- Sample sizes and output location ---------------------------------------
# Total number of rows drawn across all sources, kept numeric for arithmetic.
total_sample_sizes <- c(1e4, 5e4, 1e5, 5e5, 1e6, 1.5e6)

destination_path <- "data/embedding_combined/"
if (!dir.exists(destination_path)) {
  dir.create(destination_path, recursive = TRUE)
}

# Step 1: for every total sample size, draw an equal quota of rows from each
# source's preprocessed corpus and save the stacked sample to destination_path.
# Sizes are the outer loop and sources the inner loop; this order fixes the
# sequence of RNG draws after set.seed(), so keep it to reproduce samples.
for (total_size in total_sample_sizes) {
  # Per-source quota. floor() keeps it a whole number of rows even if
  # total_size is not a multiple of num_countries (the configured sizes are).
  size_per_country <- floor(total_size / num_countries)

  # Collect per-source samples in a preallocated list and bind once,
  # instead of growing a data frame with rbind() on every iteration.
  country_samples <- vector("list", num_countries)

  for (i in seq_along(countries)) {
    country <- countries[[i]]
    cat(paste0("Processing ", country, " for total size ", total_size, " ...\n"))

    # Input file: <origin_path><source>/<source>_combined_prepro.rds
    file_path <- paste0(origin_path, country, "/", country, "_combined_prepro.rds")
    country_comb <- readRDS(file_path)

    # Sample rows (sample_n() defaults to no replacement), capped at the
    # number of rows this source actually has.
    country_samples[[i]] <- country_comb %>%
      sample_n(min(nrow(country_comb), size_per_country))
  }

  country_comb_all <- do.call(rbind, country_samples)

  # format(scientific = FALSE) keeps names like "..._100000.rds" rather
  # than the "..._1e+05.rds" that paste0() would produce.
  saveRDS(
    country_comb_all,
    file = paste0(
      destination_path, "country_comb_all_",
      format(total_size, scientific = FALSE), ".rds"
    )
  )
}

# Step 2: for every total sample size, tokenize the saved sample, keep the
# most frequent features, build a window co-occurrence matrix (fcm), fit a
# GloVe model and compute the conText transformation matrix. Each
# intermediate object is saved to destination_path.
# After the loop, toks_fcm and local_glove hold the last sample size's objects.
n_top_feats <- 30000 # vocabulary cap; matches the "30k" file-name suffix

for (total_size in total_sample_sizes) {
  cat(paste0("Processing total size ", total_size, " ...\n"))

  total_size_chr <- format(total_size, scientific = FALSE)

  country_comb_all <- readRDS(
    file = paste0(destination_path, "country_comb_all_", total_size_chr, ".rds")
  )

  # Tokenize the "content" column. These tokens are not filtered yet, but the
  # file keeps the "30k" suffix because it is read back under that name later.
  country_comb_corpus <- corpus(country_comb_all, text_field = "content")
  toks <- tokens(country_comb_corpus)
  saveRDS(
    toks,
    file = paste0(destination_path, "combined_toks", total_size_chr, "30k.rds")
  )

  # Rank features by total count and keep the top n_top_feats. head() caps
  # the index at the vocabulary size, so a sample with fewer features than
  # the cap cannot produce NA feature names.
  combined_dfm <- dfm(toks, verbose = TRUE)
  top_feats <- featnames(combined_dfm)[head(order(-colSums(combined_dfm)), n_top_feats)]

  # padding = TRUE leaves an empty slot for each removed token, so fcm()
  # windows are measured over the original token positions.
  toks_feats <- tokens_select(toks, top_feats, padding = TRUE)
  saveRDS(
    toks_feats,
    file = paste0(destination_path, "combined_toks_feats", total_size_chr, "30k.rds")
  )

  # Full (tri = FALSE) co-occurrence counts within WINDOW_SIZE tokens,
  # with every window position weighted equally.
  toks_fcm <- fcm(
    toks_feats,
    context = "window",
    window = WINDOW_SIZE,
    count = "frequency",
    tri = FALSE,
    weights = rep(1, WINDOW_SIZE)
  )
  saveRDS(
    toks_fcm,
    file = paste0(destination_path, "combined_fcm", total_size_chr, "30k.rds")
  )

  # Fit GloVe with text2vec on all available cores.
  glove <- GlobalVectors$new(rank = DIM, x_max = 100, learning_rate = 0.05)
  wv_main <- glove$fit_transform(
    toks_fcm,
    n_iter = ITERS,
    convergence_tol = 1e-3,
    n_threads = parallel::detectCores()
  )
  # Final word vectors are the sum of the main and context vectors.
  wv_context <- glove$components
  local_glove <- wv_main + t(wv_context)
  saveRDS(
    local_glove,
    file = paste0(destination_path, "combined_local_glove", total_size_chr, "30k.rds")
  )

  local_transform <- compute_transform(x = toks_fcm, pre_trained = local_glove, weighting = 100)
  saveRDS(
    local_transform,
    file = paste0(destination_path, "combined_local_transform", total_size_chr, "30k.rds")
  )
}

# Step 3: robustness check on the conText weighting. For the largest (1.5m)
# sample, compute one transformation matrix per weighting value.
# The fcm and GloVe vectors are reloaded from the files saved in the previous
# loop rather than taken from whatever objects that loop left behind, so
# this step does not depend on the order of total_sample_sizes.
total_size <- max(total_sample_sizes)
total_size_chr <- format(total_size, scientific = FALSE)

toks_fcm <- readRDS(
  file = paste0(destination_path, "combined_fcm", total_size_chr, "30k.rds")
)
local_glove <- readRDS(
  file = paste0(destination_path, "combined_local_glove", total_size_chr, "30k.rds")
)

weightings <- c(1, 10, 50, 100, 500, 1000, 2000, 10000, 50000, 100000)
for (weight in weightings) {
  cat("Getting transform matrix for weight ", weight, "\n")
  # Format without scientific notation: paste0(100000) yields "1e+05",
  # which would break the file-naming convention used elsewhere.
  weight_chr <- format(weight, scientific = FALSE)
  local_transform <- compute_transform(x = toks_fcm, pre_trained = local_glove, weighting = weight)
  saveRDS(
    local_transform,
    file = paste0(
      destination_path, "combined_local_transform", total_size_chr,
      "_weight", weight_chr, "30k.rds"
    )
  )
}

# Step 4: robustness check on vocabulary size. Re-fit GloVe on the 1.5m
# sample keeping only the top 1k ... 20k features; output file names use the
# size label (e.g. "1k") where the main models use "30k".
feature_sizes <- c(1e3, 3e3, 5e3, 10e3, 15e3, 20e3)
names(feature_sizes) <- c("1k", "3k", "5k", "10k", "15k", "20k")

total_size_chr <- format(1.5e6, scientific = FALSE)

# Read the (unfiltered) tokenized corpus saved for the 1.5m sample
toks <- readRDS(file = paste0(destination_path, "combined_toks", total_size_chr, "30k.rds"))
combined_dfm <- dfm(toks, verbose = TRUE)

# Rank the vocabulary by total count once; each feature size takes a prefix.
# head() caps at the vocabulary size, so no NA feature names can appear.
ranked_feats <- featnames(combined_dfm)[order(-colSums(combined_dfm))]

# Iterate over the labels so each file suffix comes straight from its name,
# without a floating-point == lookup back into feature_sizes.
for (suffix in names(feature_sizes)) {
  feat_size <- feature_sizes[[suffix]]
  cat(paste0("Processing feature size ", feat_size, " ...\n"))

  top_feats <- head(ranked_feats, feat_size)
  # padding = TRUE keeps removed-token slots so fcm() windows use the
  # original token positions.
  toks_feats <- tokens_select(toks, top_feats, padding = TRUE)
  saveRDS(
    toks_feats,
    file = paste0(destination_path, "combined_toks_feats", total_size_chr, suffix, ".rds")
  )

  toks_fcm <- fcm(
    toks_feats,
    context = "window",
    window = WINDOW_SIZE,
    count = "frequency",
    tri = FALSE,
    weights = rep(1, WINDOW_SIZE)
  )
  saveRDS(
    toks_fcm,
    file = paste0(destination_path, "combined_fcm", total_size_chr, suffix, ".rds")
  )

  glove <- GlobalVectors$new(rank = DIM, x_max = 100, learning_rate = 0.05)
  wv_main <- glove$fit_transform(
    toks_fcm,
    n_iter = ITERS,
    convergence_tol = 1e-3,
    n_threads = parallel::detectCores()
  )
  # Final word vectors are the sum of the main and context vectors.
  wv_context <- glove$components
  local_glove <- wv_main + t(wv_context)
  saveRDS(
    local_glove,
    file = paste0(destination_path, "combined_local_glove", total_size_chr, suffix, ".rds")
  )

  local_transform <- compute_transform(x = toks_fcm, pre_trained = local_glove, weighting = 100)
  saveRDS(
    local_transform,
    file = paste0(destination_path, "combined_local_transform", total_size_chr, suffix, ".rds")
  )
}
